Explanation methods¶
Deep learning models are becoming better and better at making predictions. As researchers, regulators, and users, however, we are also interested in asking additional questions. Namely, we would like to explain a decision in terms of the input. Where in an image is the model focusing? What cues is the prediction based on? Does it match our expectations? Can the model be trusted?
In this practical, we will explore popular methods for explaining decisions made by image classifiers:
- Simple occlusion
- Gradient norm
- Gradient x input
- GradCAM
- Integrated gradients
With a working implementation of each method, we will compare explanations qualitatively on a few sample images.
Furthermore, we will evaluate the correctness of each method quantitatively using the deletion score.
Setup¶
!pip install "jax[cuda]" -f 'https://storage.googleapis.com/jax-releases/jax_cuda_releases.html'
!pip install \
flax optax \
'git+https://github.com/n2cholas/jax-resnet.git' \
tensorflow-datasets \
better_exceptions
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "1"
import tensorflow as tf
tf.get_logger().setLevel("WARNING")
tf.config.experimental.set_visible_devices([], "GPU")
from collections import defaultdict
from functools import partial
from typing import Sequence
import flax.core
import flax.linen as nn
import jax
import jax.numpy as jnp
import jax_resnet
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import optax
import pandas as pd
import sklearn.metrics
import tabulate
import tensorflow_datasets as tfds
import torch
import tqdm
from flax.training.train_state import TrainState
from IPython.display import display
from jax import jit, vmap
RED = np.array([1.0, 0, 0])
BLUE = np.array([0, 0, 1.0])
@jax.jit
def normalize_zero_one(x):
"""Normalize a vector between 0 and 1."""
res = (x - x.min()) / (x.max() - x.min())
res = jnp.clip(res, 0.0, 1.0)
return res
@jax.jit
def normalize_max(x):
"""Normalize a vector between -1 and 1."""
res = x / jnp.abs(x).max()
res = jnp.clip(res, -1.0, 1.0)
return res
@jax.jit
def blend(a, b, alpha: float):
"""Blend two float-valued images"""
return (1 - alpha) * a + alpha * b
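As a quick illustration of what these helpers compute (a standalone sketch that inlines the same arithmetic rather than calling the jitted functions):

```python
import jax.numpy as jnp

x = jnp.array([-2.0, 0.0, 2.0])

# Same arithmetic as normalize_zero_one(x): affine map onto [0, 1]
zero_one = jnp.clip((x - x.min()) / (x.max() - x.min()), 0.0, 1.0)
print(zero_one)  # [0.  0.5 1. ]

# Same arithmetic as normalize_max(x): divide by the largest magnitude, range [-1, 1]
signed = jnp.clip(x / jnp.abs(x).max(), -1.0, 1.0)
print(signed)  # [-1.  0.  1.]

# blend(a, b, alpha) is a convex combination, used later to overlay heatmaps on images
a, b, alpha = jnp.zeros(3), jnp.ones(3), 0.25
blended = (1 - alpha) * a + alpha * b
print(blended)  # [0.25 0.25 0.25]
```

Note that normalize_zero_one maps the most negative attribution to 0, while normalize_max preserves the sign of each attribution, which matters when we later distinguish positive (red) from negative (blue) evidence.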
Dataset¶
For simplicity, we will use the small ImageNette dataset that contains 10 easy-to-classify categories from ImageNet.
Here we load the dataset and show a few images that will be used throughout this notebook.
CLASS_NAMES = [
"tench",
"English springer",
"cassette player",
"chain saw",
"church",
"French horn",
"garbage truck",
"gas pump",
"golf ball",
"parachute",
]
def show_images(images, labels=None, logits=None, ncols=4, width_one_img_inch=3.0):
B, H, W, *_ = images.shape
nrows = int(np.ceil(B / ncols))
fig, axs = plt.subplots(
nrows,
ncols,
figsize=width_one_img_inch * np.array([1, H / W]) * np.array([ncols, nrows]),
sharex=True,
sharey=True,
squeeze=False,
facecolor="white",
)
for b in range(B):
ax = axs.flat[b]
ax.imshow(images[b])
if labels is not None:
ax.set_title(CLASS_NAMES[labels[b]])
if logits is not None:
pred = logits[b].argmax()
prob = nn.softmax(logits[b])[pred]
color = (
"blue" if labels is None else ("green" if labels[b] == pred else "red")
)
p = mpl.patches.Patch(color=color, label=f"{prob:.2%} {CLASS_NAMES[pred]}")
ax.legend(handles=[p])
fig.tight_layout()
display(fig)
plt.close(fig)
def resize(image, label):
image = tf.image.resize_with_pad(image, 224, 224)
return image / 255.0, label
ds_builder = tfds.builder("imagenette/320px-v2", data_dir=".")
ds_builder.download_and_prepare()
ds_info = ds_builder.info
print(f"Total images: {ds_info}")
ds = ds_builder.as_dataset(split="train", batch_size=None, as_supervised=True)
ds = ds.map(resize)
ds = ds.batch(8)
ds = tfds.as_numpy(ds)
viz_batch = next(iter(ds))
images, labels = viz_batch
show_images(images, labels)
Total images: tfds.core.DatasetInfo(
name='imagenette',
full_name='imagenette/320px-v2/1.0.0',
description="""
Imagenette is a subset of 10 easily classified classes from the Imagenet
dataset. It was originally prepared by Jeremy Howard of FastAI. The objective
behind putting together a small version of the Imagenet dataset was mainly
because running new ideas/algorithms/experiments on the whole Imagenet take a
lot of time.
This version of the dataset allows researchers/practitioners to quickly try out
ideas and share with others. The dataset comes in three variants:
* Full size
* 320 px
* 160 px
Note: The v2 config correspond to the new 70/30 train/valid split (released in
Dec 6 2019).
""",
config_description="""
320px variant.
""",
homepage='https://github.com/fastai/imagenette',
data_dir='imagenette/320px-v2/1.0.0',
file_format=tfrecord,
download_size=325.84 MiB,
dataset_size=332.71 MiB,
features=FeaturesDict({
'image': Image(shape=(None, None, 3), dtype=uint8),
'label': ClassLabel(shape=(), dtype=int64, num_classes=10),
}),
supervised_keys=('image', 'label'),
disable_shuffling=False,
nondeterministic_order=False,
splits={
'train': <SplitInfo num_examples=9469, num_shards=2>,
'validation': <SplitInfo num_examples=3925, num_shards=1>,
},
citation="""@misc{imagenette,
author = "Jeremy Howard",
title = "imagenette",
url = "https://github.com/fastai/imagenette/"
}""",
)
def load_resnet(size):
"""Load a resnet model and return resnet_logits_fn and its variables.
Returns:
logits_fn: a jitted function that given one image applies
the resnet model and returns the max logit
value and the logits vector
variables: resnet variables to use with logits_fn
"""
def logits_fn(variables, img):
# img: [H, W, C], float32 in range [0, 1]
assert img.ndim == 3
img = normalize_for_resnet(img)
logits = model.apply(variables, img[None, ...])[0]
logits = imagenet_to_imagenette_logits(logits)
return logits.max(), logits
ResNet, variables = jax_resnet.pretrained_resnet(size)
model = ResNet()
logits_fn = jax.jit(logits_fn)
return logits_fn, variables
def normalize_for_resnet(image):
mean = jnp.array([0.485, 0.456, 0.406])
std = jnp.array([0.229, 0.224, 0.225])
return (image - mean) / std
def imagenet_to_imagenette_logits(logits):
"""Select the 10 imagenette classes from the 1000 imagenet classes."""
return logits[..., [0, 217, 482, 491, 497, 566, 569, 571, 574, 701]]
logits_fn, variables = load_resnet(size=18)
images, labels = viz_batch
print(f"images:{images.shape}")
_, logits = jax.vmap(logits_fn, (None, 0))(variables, images)
print(f"logits shape: {logits.shape}")
print(f"images shape: {(images.shape[0], len(CLASS_NAMES))}")
assert logits.shape == (images.shape[0], len(CLASS_NAMES))
show_images(images, labels, logits)
Using cache found in /Users/silpasoninallacheruvu/.cache/torch/hub/pytorch_vision_v0.10.0
images:(20, 224, 224, 3) logits shape: (20, 10) images shape: (20, 10)
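The jax.vmap(logits_fn, (None, 0)) call above maps over the leading batch axis of images while sharing variables across all elements. A minimal standalone sketch of the in_axes=(None, 0) pattern, with a toy function standing in for logits_fn:

```python
import jax
import jax.numpy as jnp

def weighted_sum(w, x):
    # w plays the role of the shared model variables, x the role of one image
    return w * x.sum()

xs = jnp.ones((4, 3))  # a "batch" of 4 inputs

# None: broadcast w to every batch element; 0: map over the leading axis of xs
out = jax.vmap(weighted_sum, in_axes=(None, 0))(2.0, xs)
print(out)  # [6. 6. 6. 6.]
```

This is why logits_fn can be written for a single image: vmap adds the batch dimension from the outside, without any change to the function body.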
Pretrained ResNet¶
For the explanations we will focus on a ResNet-18 model, whose weights have been ported from PyTorch thanks to this repo.
The simplest way to load and run a ResNet model using jax_resnet is:
ResNet, variables = jax_resnet.pretrained_resnet(size)
model = ResNet()
img = jnp.zeros((224, 224, 3)) # [H, W, C]
logits = model.apply(variables, img[None, ...])[0] # [1000]
Task 1¶
Here we load a pre-trained model and prepare it for our purposes. We want the following:
- The function should operate on a single image instead of a batch. Although counterintuitive, this will make it easier to reason about explanations later and is more in tune with the philosophy of JAX.
- The function should take care of normalizing the image with mean [0.485, 0.456, 0.406] and std [0.229, 0.224, 0.225], as done for the PyTorch models that this model was converted from. Refer to torchvision.transforms.Normalize for an example.
- Out of the 1000 ImageNet classes, select the 10 ImageNette classes that we are interested in.
- The function should return the largest element of the 10-dimensional logits vector, since later on we'll often compute gradients of it. The full logits vector should also be returned for prediction and visualization purposes.
Complete the function logits_fn returned by load_resnet so that it fulfills the requirements above.
Upon executing the cell you should see 7/8 correct predictions with almost-certain confidence.
Explanation methods¶
Occlusion¶
The simplest explanation method consists of removing patches of the input image and measuring the effect on prediction confidence. Specifically, we want to measure the drop (or increase) in confidence in the predicted class between the original non-occluded image and an occluded version.
We will use a single square patch of fixed size that is scanned over the entire image without overlap, although it would be possible to come up with more advanced patterns of occlusion.
Task 2¶
Complete the function prepare_occlusions that takes in a single image of shape [H, W, 3]
and outputs a batch of images of shape [S, S, H, W, 3] where the image at [i, j] contains
a black patch of size [H/S, W/S] whose top-left corner is placed at [i*H/S, j*W/S].
Explained with a drawing:
imgs[i, j] =
j*W/S
┌───────────┬────┬────┐
│ | | │
│ | | │
i*H/S ├ ─ ─ ─ ─ ─ ┼────┤ │
│ │####│ │
│ │####│ │
├ ─ ─ ─ ─ ─ ┴────┘ │
│ │
│ │
│ │
│ │
│ │
└─────────────────────┘
Remember that in JAX, arrays cannot be modified in place.
Use at[].set() instead:
x[idx] = y # Bad
x = x.at[idx].set(y) # Good
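To see the difference concretely, here is a standalone example of the functional update pattern:

```python
import jax.numpy as jnp

x = jnp.zeros((2, 2))
# x[0, 0] = 1.0 would raise a TypeError: JAX arrays are immutable
y = x.at[0, 0].set(1.0)  # returns a NEW array; x itself is unchanged
print(float(x[0, 0]), float(y[0, 0]))  # 0.0 1.0
```

Under jit, these functional updates are compiled into efficient in-place writes, so there is no extra copy in practice.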
Once the missing lines in prepare_occlusions are filled in, visualize the resulting batch of partially-occluded images to check your implementation.
def prepare_occlusions(img, steps: int):
H, W, _ = img.shape
imgs = jnp.tile(img, (steps, steps, 1, 1, 1))
for i in range(0, steps):
for j in range(0, steps):
imgs = imgs.at[i, j, int(i*H/steps):int((i+1)*H/steps), int(j*W/steps):int((j+1)*W/steps), :].set(0)
# imgs: [steps, steps, H, W, 3]
return imgs
prepare_occlusions = jax.jit(prepare_occlusions, static_argnames="steps")
show_images(
prepare_occlusions(viz_batch[0][0], steps=3).reshape(-1, 224, 224, 3),
ncols=3,
width_one_img_inch=1.5,
)
Task 3¶
Using prepare_occlusions implemented above, complete the missing lines in occlusion_fn following this pseudocode:
probs = f(img)
idx = argmax(probs)
imgs = prepare_occlusions(img)
relevance[i, j] = probs[idx] - f(imgs[i, j])[idx]
relevance = resize(relevance, img.shape)
With a working implementation, the code below will show positive and negative attributions for eight images. Positive attribution is shown as a red overlay, while negative attribution is shown in blue (almost invisible except for the last image).
Note: jit and vmap take care of speeding up and vectorizing occlusion_fn
so that it works on a batch of images.
You will see them used as wrappers or decorators throughout the notebook.
Tips:
- You want to compute how much the probability of the original prediction drops: apply softmax to the output of logits_fn to get probabilities and select the right class with idx.
- Apply vmap twice to logits_fn to vectorize it over the two extra axes added by prepare_occlusions; you don't need two nested for loops.
- Use jax.image.resize with method="bilinear" to resize the heatmap to the original size.
- Use normalize_max to rescale the attributions to a range that works well with the visualization code.
def occlusion_fn(logits_fn, variables, img, steps: int):
H, W, _ = img.shape
_, logits_orig = logits_fn(variables, img)
probs = nn.softmax(logits_orig)
idx = logits_orig.argmax()
imgs = prepare_occlusions(img, steps)
logits_occ_fn = jax.vmap(
jax.vmap(logits_fn, (None,0)),
(None,0)
)
_, logits_occ = logits_occ_fn(variables, imgs)
probs_occ = nn.softmax(logits_occ, axis=-1)
relevance = probs[idx] - probs_occ[..., idx]
relevance = jax.image.resize(relevance, (H, W), method="bilinear")
attrib = normalize_max(relevance)
# logits_orig: [num_classes]
# attrib: [H, W]
return logits_orig, attrib
occlusion_fn = jax.jit(occlusion_fn, static_argnames=["logits_fn", "steps"])
occlusion_fn = jax.vmap(occlusion_fn, in_axes=(None, None, 0, None))
images, labels = viz_batch
logits_fn, variables = load_resnet(size=18)
logits, relevance = occlusion_fn(logits_fn, variables, images, 6)
images = blend(images, RED, jnp.clip(relevance, a_min=0)[..., None])
images = blend(images, BLUE, -jnp.clip(relevance, a_max=0)[..., None])
show_images(images, labels, logits)
Grad norm (sensitivity)¶
An important tool for decision explanation is the gradient of the prediction function with respect to the input variable evaluated at the input image. Intuitively, the gradient expresses how much a change in the input would affect the prediction (actually the pre-softmax confidence). By evaluating the gradient at the input image, we can estimate the relevance $R_i$ of each pixel $i$.
$$ \begin{align} X &\in \mathbb{R}^{D} \\ p &= f(X) \\ R_i &= \left[\nabla_X f(X)\right]_i \end{align} $$
Since our models operate on images, the gradient w.r.t. an image will have shape [H, W, 3].
For ease of visualization, we will compute the norm of the gradient at each pixel location and visualize it as a heatmap of shape [H, W].
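This recipe can be sketched end to end on a toy differentiable function (hypothetical, standing in for the model's max logit): jax.grad returns an array with the same shape as the input, and the channel-wise norm turns it into a [H, W] heatmap.

```python
import jax
import jax.numpy as jnp

# toy differentiable "max logit" on a tiny [2, 2, 3] image (hypothetical, for illustration)
def max_logit(img):
    return jnp.sum(img ** 2)

img = jnp.arange(12, dtype=jnp.float32).reshape(2, 2, 3) / 12.0
grad = jax.grad(max_logit)(img)         # same shape as the input: [2, 2, 3]
heat = jnp.linalg.norm(grad, axis=-1)   # pixel-wise gradient norm: [2, 2]
```

For this toy function the gradient is simply 2 * img, which makes it easy to verify that jax.grad differentiates with respect to every pixel and channel at once.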
Task 4¶
Complete the missing lines of grad_norm_fn so that given a single input image it returns the associated logits and the pixel-wise norm of the gradient of the most confident prediction.
Also, since we want to overlay the explanation on the image, make sure to scale the results to the range [0, 1].
Tips:
- The function logits_fn prepared in task 1 returns the maximum logit as its first return value.
- In JAX one can use jax.value_and_grad to decorate a function so that both the value and its gradient are returned.
- The function jax.value_and_grad can also take an extra parameter has_aux to indicate that the original function returns more than one value and that those extra values should be returned by the decorated function too. Example:
def foo(a, x):
    y = jnp.exp(x**2) - jnp.sin(a @ x)
    return y.sum(), y
foo_vg = jax.value_and_grad(foo, argnums=1, has_aux=True)
(y_sum, y), grad_x = foo_vg(a, x)
- Use normalize_max to rescale the attributions to a range that works well with the visualization code.
def grad_norm_fn(logits_fn, variables, img):
H, W, _ = img.shape
logits_vg_fn = jax.value_and_grad(logits_fn, argnums=1, has_aux=True)
(_, logits), grads = logits_vg_fn(variables, img)
heat = jnp.linalg.norm(grads, axis=-1)
grad = normalize_max(heat)
# logits: [num_classes]
# grad: [H, W]
return logits, grad
grad_norm_fn = jax.jit(grad_norm_fn, static_argnames=["logits_fn"])
grad_norm_fn = jax.vmap(grad_norm_fn, in_axes=(None, None, 0))
images, labels = viz_batch
logits_fn, variables = load_resnet(size=18)
logits, relevance = grad_norm_fn(logits_fn, variables, images)
show_images(
blend(images, RED, relevance[..., None]),
labels,
logits,
)
Grad x input¶
To increase the sharpness of the explanations, it's possible to multiply the gradient by the corresponding input. Intuitively, the gradient expresses the importance of a certain feature and is now rescaled by how much that feature is present.
$$R_i = X_i \cdot \nabla f(X_i)$$
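A toy linear logit (with hypothetical weights w) makes the difference from the plain gradient visible: a heavily weighted but absent feature gets zero attribution, and so does a strongly present but irrelevant one.

```python
import jax
import jax.numpy as jnp

w = jnp.array([1.0, 0.0, 2.0])   # hypothetical feature weights

def logit(x):
    return jnp.dot(w, x)

x = jnp.array([3.0, 5.0, 0.0])   # feature 1 is present but irrelevant; feature 2 is relevant but absent
grad = jax.grad(logit)(x)        # plain gradient: [1., 0., 2.] regardless of x
relevance = x * grad             # grad x input: [3., 0., 0.]
```

Only the feature that is both present and influential receives attribution, which is exactly the rescaling described above.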
Task 5¶
Modify grad_norm_fn so that the gradient is multiplied with the image before computing the norm.
def grad_x_input_fn(logits_fn, variables, img):
H, W, _ = img.shape
logits_vg_fn = jax.value_and_grad(logits_fn, argnums=1, has_aux=True)
(_, logits), grads = logits_vg_fn(variables, img)
grads_x = img * grads
heat_x = jnp.linalg.norm(grads_x, axis=-1)
grad = normalize_max(heat_x)
# logits: [num_classes]
# grad: [H, W]
return logits, grad
grad_x_input_fn = jax.jit(grad_x_input_fn, static_argnames=["logits_fn"])
grad_x_input_fn = jax.vmap(grad_x_input_fn, in_axes=(None, None, 0))
images, labels = viz_batch
logits_fn, variables = load_resnet(size=18)
logits, relevance = grad_x_input_fn(logits_fn, variables, images)
show_images(
blend(images, RED, relevance[..., None]),
labels,
logits,
)
Integrated gradients¶
As observed, gradient-based explanations appear very noisy. This is because gradients can only describe what happens in the local neighborhood of the input image when a pixel is changed by a small quantity, therefore:
- some pixels might be very important for the prediction (e.g. the color of a flower), but the local gradient might be saturated (e.g. different shades of yellow will all give the same confidence in "sunflower") and therefore those pixels will not be marked as relevant.
- a small step in the direction of the gradient might increase the prediction confidence, but a slightly larger step might decrease it, and an even slightly larger step might increase it again.
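The first failure mode (saturation) can be reproduced with a sigmoid: for a large input the output is confidently high, yet the local gradient is nearly zero, so a purely gradient-based method would mark the input as irrelevant.

```python
import jax

f = jax.nn.sigmoid
x = 10.0
conf = f(x)          # ~0.99995: the input clearly drives the prediction
g = jax.grad(f)(x)   # ~4.5e-5: yet the local gradient has almost vanished
```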
The integrated gradients method addresses this issue by aggregating gradients along a linear path between the input image and a baseline (usually black). By considering a path, the noise associated with local gradients is reduced. Also, using a baseline image allows the explanation to be expressed in relative rather than absolute terms.
By expressing the path as $\gamma(\alpha) = B + \alpha(X-B)$, the method can be expressed as: $$ R_i = \int_0^1 \frac{\partial f(\gamma(\alpha))}{\partial\gamma_i(\alpha)} \frac{\partial\gamma_i(\alpha)}{\partial\alpha} \ d\alpha. $$
Which can be approximated as: $$ R_i \approx (X_i - B_i) \frac{1}{M} \sum_{m=1}^M \frac{\partial f\left(B + m/M(X-B)\right)}{\partial X_i}. $$
Where $B$ indicates the black baseline, $M$ is the number of steps for approximating the path integral.
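The finite-sum approximation above can be sketched on a toy function (hypothetical, with a zero baseline). A useful sanity check is the completeness property from the integrated gradients paper: the attributions should sum approximately to $f(X) - f(B)$.

```python
import jax
import jax.numpy as jnp

def f(x):                        # toy scalar "logit" (hypothetical)
    return jnp.sum(x ** 2)

def integrated_gradients(f, x, baseline, steps):
    alphas = jnp.linspace(1.0 / steps, 1.0, steps)                  # m/M for m = 1..M
    path = baseline[None] + alphas[:, None] * (x - baseline)[None]  # [steps, D] points on the path
    grads = jax.vmap(jax.grad(f))(path)                             # gradient at each path point
    return (x - baseline) * grads.mean(axis=0)                      # (X - B) * average gradient

x = jnp.array([1.0, -2.0, 3.0])
ig = integrated_gradients(f, x, jnp.zeros_like(x), steps=200)
# completeness: attributions sum approximately to f(x) - f(baseline) = 14
```

With more steps the Riemann sum converges to the path integral and the completeness identity becomes exact.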
The function below computes all intermediate images between an input img and a black baseline.
@partial(jax.jit, static_argnames=["steps"])
def prepare_integrated_gradients(img, steps: int):
assert img.ndim == 3
return img[None, :, :, :] * jnp.linspace(1, 0, num=steps)[:, None, None, None]
image = viz_batch[0][0]
images = prepare_integrated_gradients(image, steps=8).reshape(-1, 224, 224, 3)
show_images(images, width_one_img_inch=2)
Task 6¶
Complete the function integrated_grad_fn so that:
- a single image is taken as input
- a prediction is made on the input image to determine the predicted class
- a batch of progressively darker images is prepared with prepare_integrated_gradients
- for each image in the batch, the gradient of the logit at idx is computed
- the path integral is approximated using a finite sum
According to the official implementation, only positive attributions are considered and attributions are averaged per pixel.
Tips:
- Store the index of the most-confident prediction for the input image as idx, because we need to refer to it when computing gradients.
- At each intermediate step you don't want the gradient of max_logit, i.e. the first output of logits_fn, which you would get from grad(logits_fn). Instead you want the gradient of the idx-th element of logits, i.e. the second output. Define a local function or a lambda and call grad on that.
- You don't need for loops, use vmap.
def integrated_grad_fn(logits_fn, variables, img, steps: int):
H, W, _ = img.shape
# model's predicted class
_, logits_orig = logits_fn(variables, img)
idx = logits_orig.argmax()
baseline = jnp.zeros_like(img)
images = prepare_integrated_gradients(img, steps).reshape(-1, H, W, 3)
# function to call grad on idx-th element of logits
def grads_idx_fn(variables, img_):
logit_max, logit = logits_fn(variables, img_)
val = logit[idx]
return val, logit_max
value_and_grad_fn = jax.value_and_grad(grads_idx_fn, argnums=1, has_aux=True)
(_,_), grads = jax.vmap(lambda im: value_and_grad_fn(variables, im), in_axes=0)(images)
avg_grads = grads.mean(axis=0)
ig = (img - baseline) * avg_grads
# as in the official implementation: keep positive attributions, average per pixel
heat = jnp.maximum(ig, 0).mean(axis=-1)
grads = normalize_max(heat)
# logits: [num_classes]
# grads: [H, W]
return logits_orig, grads
integrated_grad_fn = jax.jit(integrated_grad_fn, static_argnames=["logits_fn", "steps"])
integrated_grad_fn = jax.vmap(integrated_grad_fn, in_axes=(None, None, 0, None))
images, labels = viz_batch
logits_fn, variables = load_resnet(size=18)
logits, relevance = integrated_grad_fn(logits_fn, variables, images, 25)
show_images(
blend(images, RED, relevance[..., None]),
labels,
logits,
)
GradCAM¶
GradCAM decomposes the ResNet model in two blocks: a CNN backbone and a linear classifier, separated by global average pooling. $$ \begin{align} X &\in \mathbb{R}^{H\times W\times 3} \\ A &= \text{Backbone}(X) \in \mathbb{R}^{H'\times W'\times K}\\ Y &= \text{Linear}(\text{GAP}(A)) \in \mathbb{R}^C \end{align} $$
The main idea is to consider the gradient of the activations before global average pooling and use them to rescale the intensity of the associated feature maps. Specifically, if the predicted class is $c$, the scaling factor for the $k$-th feature map is: $$ \alpha_c^k = \frac{1}{H'W'} \sum_i^{H'}\sum_j^{W'} \frac{\partial Y^c}{\partial A^k_{i,j}} $$
The feature maps are then combined into a downsized attribution map as: $$ R_c = \text{ReLU}\left( \sum_k^K \alpha_c^k A^k\right) \in \mathbb{R}^{H'\times W'} $$
Finally, the relevance heatmap is resized to match the input image. Compared to the gradient-based methods above, GradCAM produces much smoother heatmaps thanks to this upsampling operation.
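With hypothetical feature maps and gradients of the right shapes, the channel weighting, ReLU, and upsampling steps can be sketched as follows (jnp.einsum performs the weighted sum over channels).

```python
import jax
import jax.numpy as jnp

Hp, Wp, K = 7, 7, 4                                               # hypothetical backbone output shape [H', W', K]
A = jnp.arange(Hp * Wp * K, dtype=jnp.float32).reshape(Hp, Wp, K) # hypothetical feature maps A
dY_dA = jnp.ones((Hp, Wp, K)) * jnp.array([1.0, -1.0, 0.5, 0.0])  # hypothetical gradients dY^c/dA

alpha = dY_dA.mean(axis=(0, 1))                                 # [K]: one weight per feature map
cam = jnp.maximum(jnp.einsum("hwk,k->hw", A, alpha), 0.0)       # weighted sum + ReLU: [H', W']
heatmap = jax.image.resize(cam, (224, 224), method="bilinear")  # upsample to the input size
```

The smoothness of GradCAM heatmaps comes directly from this last step: a 7x7 map stretched to 224x224 cannot contain high-frequency detail.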
Task 7¶
Implement the missing parts of grad_cam_fn:
- First, process the image through the backbone
- Then, process the features through global average pooling, then through the classifier. Remember that you'll need the gradients w.r.t. these features.
- Once you have the gradients, combine them with the features as indicated above and resize the relevance to the same size of the input image.
- As usual, return both the logits vector and the relevance matrix.
Tips:
- You will need to apply the backbone and the classifier separately. The function load_resnet_for_grad_cam takes care of splitting the model and its variables for you. They are returned as two dictionaries, both containing the keys backbone and gap_cls.
- When performing complex sums and products, jnp.einsum can drastically reduce the amount of error-prone reshaping code required.
def grad_cam_fn(fns, variables, img):
H, W, _ = img.shape
backbone_fn = fns["backbone"]
gap_classifier_fn = fns["gap_cls"]
backbone_vars = variables["backbone"]
gap_classifier_vars = variables["gap_cls"]
# apply image through backbone
features = backbone_fn(backbone_vars, img)
_, logits = gap_classifier_fn(gap_classifier_vars, features)
# fix the target class c
c = jnp.argmax(logits)
# scalar logit function for class c
def class_logit_fn(vars_, feats_):
_, logits = gap_classifier_fn(vars_, feats_)
# scalar Y^c
return logits[c]
# gradients wrt features for class c
vgf = jax.value_and_grad(class_logit_fn, argnums=1, has_aux=False)
_, grads = vgf(gap_classifier_vars, features)
alpha = grads.mean(axis=(0,1))
relevance = jnp.einsum("hwc,c->hw", features, alpha)
relevance = jnp.maximum(relevance, 0)
# resize to input image size
relevance_resized = jax.image.resize(
relevance, (H, W), method="bilinear"
)
relevance_resized = normalize_max(relevance_resized)
# logits: [num_classes]
# grad: [H, W]
return logits, relevance_resized
def load_resnet_for_grad_cam(size):
@jax.jit
def backbone_fn(variables, img):
# img: [H, W, C], float32 in range [0, 1]
# feats: [h, w, c], float32
img = normalize_for_resnet(img)
feats = backbone.apply(variables, img[None, ...], mutable=False)[0]
return feats
@jax.jit
def gap_classifier_fn(variables, feats):
# feats: [h, w, c], float32
# logit: float32
# logits: [10], float32
logits = gap_classifier.apply(variables, feats[None, ...], mutable=False)[0]
logits = imagenet_to_imagenette_logits(logits)
return logits.max(), logits
ResNet, variables = jax_resnet.pretrained_resnet(size)
model = ResNet()
backbone = nn.Sequential(model.layers[:-2])
backbone_vars = jax_resnet.slice_variables(variables, start=0, end=-2)
gap_classifier = nn.Sequential(model.layers[-2:])
gap_classifier_vars = jax_resnet.slice_variables(variables, start=len(model.layers) - 2, end=None)
return (
flax.core.freeze({"backbone": backbone_fn, "gap_cls": gap_classifier_fn}),
flax.core.freeze({"backbone": backbone_vars, "gap_cls": gap_classifier_vars}),
)
grad_cam_fn = jax.jit(grad_cam_fn, static_argnames=["fns"])
grad_cam_fn = jax.vmap(grad_cam_fn, in_axes=(None, None, 0))
images, labels = viz_batch
fns, variables = load_resnet_for_grad_cam(size=18)
logits, relevance = grad_cam_fn(fns, variables, images)
show_images(
images * relevance[..., None],
labels,
logits,
)
Deletion score¶
So far, we have evaluated the explanations qualitatively by drawing them as heatmaps over a few sample images. A common metric used to evaluate the correctness of an explanation method quantitatively is the deletion score. It is computed by progressively removing pixels from an image in order of importance and measuring the corresponding drop in confidence. The behavior can be visualized on a plot that has the percentage of removed pixels on the horizontal axis and the prediction confidence on the vertical axis.
Ideally, if the most-relevant pixels are actually important for the prediction, their removal should induce a sudden drop in confidence for that label. To summarize this idea with a number, we can compute the area under the curve: a low area indicates a quick decline in confidence, and hence a good explanation method. This value is called the deletion score.
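On a hypothetical confidence curve, the deletion score is just the trapezoidal area under confidence versus fraction of pixels removed, equivalent to what sklearn.metrics.auc computes later in the notebook.

```python
import numpy as np

# hypothetical confidence curve: confidence drops quickly as relevant pixels are removed
fractions = np.linspace(0.0, 1.0, 6)   # fraction of pixels removed
confidence = np.array([0.95, 0.40, 0.15, 0.08, 0.04, 0.02])

# trapezoidal area under the curve: lower means a faster drop, i.e. a better explanation
deletion_score = float(np.sum((confidence[1:] + confidence[:-1]) / 2 * np.diff(fractions)))
```

Here the sharp initial drop yields a small area (about 0.23), while a flat curve near 1.0 would yield an area close to 1.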
Single-image deletion score¶
Task 8¶
Implement a function prepare_deletion that given an image and the associated relevance prepares a batch of steps images.
If we indicate the resulting batch as imgs, note that:
- imgs[0] corresponds to the input image
- imgs[s] is a copy of the input image with s/(steps-1) percent of black pixels
- imgs[-1] is an all-black image
- pixels are set to zero in order of relevance, with the most relevant ones first
- if a pixel (i, j) is set to black at step s, it will remain black in all subsequent steps, i.e. the images become progressively more black
The code below samples a random relevance mask and shows the images resulting from prepare_deletion so that you can verify the implementation.
Check that the first regions to become black are the ones with the highest relevance.
Tips:
- It's easier to reason about the flattened versions of the image and the relevance matrices; use jnp.ndarray.flatten and jnp.unravel_index to move back and forth between one and two dimensions.
- Use jnp.argsort and jnp.array_split to sort and split the relevance, but be careful about sorting in ascending/descending order.
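As a minimal sketch of these tips, with a hypothetical 2x2 relevance map: flatten, sort descending, map the flat indices back to 2-D coordinates, and split them into per-step chunks.

```python
import jax.numpy as jnp

relevance = jnp.array([[0.9, 0.1],
                       [0.5, 0.7]])
order = jnp.argsort(relevance.flatten())[::-1]    # flat indices, most relevant first
i, j = jnp.unravel_index(order, relevance.shape)  # back to 2-D coordinates
# the first pixel to delete is (0, 0), the one with relevance 0.9
chunks = jnp.array_split(order, 2)                # one chunk of flat indices per deletion step
```

Note that jnp.argsort sorts in ascending order, hence the [::-1] reversal to delete the most relevant pixels first.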
def prepare_deletion(img, relevance, steps: int):
assert relevance.shape == img.shape[:2]
H, W, _ = img.shape
imgs = jnp.tile(img, (steps, 1, 1, 1))
relevance = relevance.flatten()
indices = jnp.argsort(relevance)[::-1]
# chunks of index sections to mask for each step
prev_img = img
idx_sections = jnp.array_split(indices, steps-1)
for s in range(1, steps):
idxs = idx_sections[s-1]
i, j = jnp.unravel_index(idxs, (H,W))
curr_img = prev_img.at[i, j, :].set(0)
imgs = imgs.at[s].set(curr_img)
prev_img = curr_img
# imgs: [steps, H, W, 3]
return imgs
prepare_deletion = jax.jit(prepare_deletion, static_argnames="steps")
relevance = jax.random.uniform(jax.random.PRNGKey(42), (7, 7))
relevance = jax.image.resize(relevance, (224, 224), method="bilinear")
relevance = normalize_zero_one(relevance)
image = plt.get_cmap('viridis')(relevance)[..., :3]
steps=8
images = prepare_deletion(image, relevance, steps)
assert images.shape == (steps, *image.shape)
show_images(images, width_one_img_inch=2)
Task 9¶
Using the prepare_deletion function implemented above, complete the missing lines of deletion_score_fn:
- The function takes as input an image, its relevance, and a number of steps
- The function returns a vector of length steps containing the probabilities associated with the top-scoring class predicted by the model as more and more relevant pixels are removed
- The function also returns the original prediction pred_orig
The cell below contains a few lines of code for plotting the resulting curve and the associated score. You should see the confidence curve slowly decreasing to zero and eventually rising slightly.
Tips:
- The expected value for the area under the curve is 0.270
- You don't need for loops, use vmap
def deletion_score_fn(logits_fn, variables, img, relevance, steps):
H, W, _ = img.shape
# original prediction
_, logits_orig = logits_fn(variables, img)
pred_orig = logits_orig.argmax()
# imgs with deleted pixels
imgs = prepare_deletion(img, relevance, steps)
# get probs for each step
_, logits_all = jax.vmap(logits_fn, (None, 0))(variables, imgs)
probs_all = nn.softmax(logits_all, axis=-1)
probs = probs_all[:, pred_orig]
# probs: [steps]
# pred_orig: int
return probs, pred_orig
deletion_score_fn = jax.jit(deletion_score_fn, static_argnames=["logits_fn", "steps"])
image = viz_batch[0][2]
relevance = jax.random.uniform(jax.random.PRNGKey(42), (7, 7))
relevance = jax.image.resize(relevance, image.shape[:2], method="bilinear")
relevance = normalize_zero_one(relevance)
steps = 8
logits_fn, variables = load_resnet(size=18)
probs, pred = deletion_score_fn(logits_fn, variables, image, relevance, steps)
auc = sklearn.metrics.auc(np.linspace(0,1,steps), probs)
assert len(probs) == steps
fig, ax = plt.subplots(1, 1, figsize=(9, 4))
ax.fill_between(np.linspace(0, 1, steps), probs)
ax.set_ylabel(f"Confidence for '{CLASS_NAMES[pred]}'")
ax.grid(axis="y")
ax.yaxis.set_major_formatter(mpl.ticker.PercentFormatter(xmax=1))
ax.xaxis.set_major_formatter(mpl.ticker.PercentFormatter(xmax=1))
ax.set_xlabel("Pixels removed")
ax.set_title(f"Deletion score: {auc:.3f}");
Visualize deletion on some images¶
The following code visualizes the explanations of all methods for the eight images used so far. On the right the deletion curve and its area is also plotted.
Task 10¶
No implementation required, just take a moment to compare the explanations:
- How do they look side-by-side? Does the deletion curve match what you see in the image?
- Which one do you trust the most? Motivate your feeling!
- Does your judgement correlate well with the deletion score?
- Is there a method that is consistently better?
Add your comments below:
- Yes, the deletion curves generally agree with the visual explanations: smooth, focused heatmaps correspond to steep drops (better attribution). For example, in images like the gas pump, where the red region is concise and sits on the object (as with GradCAM or IG), the confidence drops quickly, giving low deletion scores. For the church images, where the heatmap is noisy or off-target (as with raw Grad), the curve fluctuates or stays high for longer.
- Looking at the images, I trust GradCAM the most, as it clearly focuses on the key semantic regions of each object (e.g., the center of the gas pump, the body of the truck, or the middle of the horn), making it easy to interpret visually. Integrated Gradients gives similar relevance with smoother transitions and less noise than plain Grad or Grad×Input. Occlusion is interpretable but coarse: it blocks large regions and therefore loses spatial precision.
- Largely, yes. Images with object-aligned heatmaps (IG) also have low deletion scores, meaning the method correctly identified critical pixels, while methods with noisy or diffuse heatmaps (Grad, Grad×Input) show higher and fluctuating curves, confirming less stable explanations. GradCAM focuses on interpretable, high-level regions rather than pixel-exact importance, so even though its heatmaps look convincing, they can have high deletion scores because they lack sharp, pixel-level sensitivity. Similarly, the deletion behavior shows that occlusion finds the right regions but not the exact pixels: confidence drops slowly because each removed patch mixes relevant and irrelevant pixels, confirming that occlusion is reliable but coarse compared to gradient-based methods.
- Yes: Integrated Gradients consistently shows the lowest deletion score, confirming its quantitative reliability, while GradCAM is consistently strong visually and gives interpretable, smooth heatmaps.
logits_fn, variables = load_resnet(size=18)
logits_fns_gc, variables_gc = load_resnet_for_grad_cam(size=18)
# Use lambdas instead of partial because vmap doesn't play well with kwargs
all_methods = {
"occlusion": lambda images: occlusion_fn(logits_fn, variables, images, 6),
"grad": lambda images: grad_norm_fn(logits_fn, variables, images),
"grad_x_input": lambda images: grad_x_input_fn(logits_fn, variables, images),
"grad_cam": lambda images: grad_cam_fn(logits_fns_gc, variables_gc, images),
"integrated_gradients": lambda images: integrated_grad_fn(logits_fn, variables, images, 20),
}
# These logits are only needed for visualization
images, labels = viz_batch
_, logits = jax.vmap(logits_fn, (None, 0))(variables, images)
deletion_score_fn_vmap = jax.vmap(deletion_score_fn, in_axes=(None, None, 0, 0, None))
fig, axs = plt.subplots(
len(images),
len(all_methods) + 1,
figsize=((len(all_methods) + 3) * 3, len(images) * 3),
gridspec_kw={"width_ratios": len(all_methods) * [1] + [3], "wspace": 0.001},
)
# Write true/predicted class on the left side of each row
for ax, lb, lg in zip(axs[:, 0], labels, logits):
ax.set_ylabel(f"True {CLASS_NAMES[lb]}\nPred {CLASS_NAMES[lg.argmax()]}")
# Column headers: method names + deletion curve
for ax, method in zip(axs[0, :-1], all_methods.keys()):
ax.set_title(method)
axs[0, -1].set_title("Deletion curve")
# Each explanation method gets its own column on the left and its own curve on the right
for method_col, (method, method_fn) in zip(axs[:, :-1].T, all_methods.items()):
_, relevance = method_fn(images)
for ax, img, rel in zip(method_col, images, relevance):
ax.imshow(blend(img, RED, normalize_zero_one(rel)[..., None]))
probs, _ = deletion_score_fn_vmap(logits_fn, variables, images, relevance, 25)
for ax, p in zip(axs[:, -1], probs):
auc = sklearn.metrics.auc(np.linspace(0, 1, len(p)), p)
ax.plot(np.linspace(0, 1, len(p)), p, label=f"{auc:.3f} {method}")
# Remove inner ticks for the image grid
for ax in axs[:-1, :-1].flat:
ax.set_xticks([])
for ax in axs[:, 1:-1].flat:
ax.set_yticks([])
# Annotate right column on the rightmost edge
for ax in axs[:, -1]:
axt = ax.twinx()
axt.yaxis.set_major_formatter(mpl.ticker.PercentFormatter(xmax=1))
axt.set_ylabel("Confidence")
axt.grid()
ax.set_yticks([])
ax.legend(loc="upper right", framealpha=1.0)
# Remove pixel percent ticks from right column, except at the bottom
for ax in axs[:-1, -1]:
ax.set_xticklabels([])
axs[-1, -1].xaxis.set_major_formatter(mpl.ticker.PercentFormatter(xmax=1))
axs[-1, -1].set_xlabel("Pixels removed");
Average deletion score on entire dataset¶
To better compare the explanation methods, we can compute the average deletion score across the entire dataset. This value should give us an indication of which method is best at identifying relevant pixels for a prediction.
Task 11¶
No implementation required, just consider the results of this evaluation:
- Which method seems to be best?
- Can you trust results with such a high standard deviation?
- What can be the cause of it? Think both of how the metric is computed and of how the content of an image might affect the score.
Add your comments below:
- Integrated Gradients seems to be the best, with the lowest mean deletion score, indicating that it identifies the most relevant pixels overall.
- The standard deviations are quite large, almost comparable with the mean value, meaning the deletion score varies a lot from one image to another and that we should be cautious when interpreting the ranking.
- One of the reasons for high standard deviation could be that the deletion score depends on how quickly the model’s confidence drops as pixels are removed. This can vary dramatically between images depending on the object’s size, location, and background complexity. Images where the object occupies most of the frame (e.g., a church filling the image) will show slow confidence decay, whereas small, localized objects (like a gas pump) will drop sharply.
Warning: the following code might take a long time to run and/or run out of memory. For reference, on a single GPU with 10 GB of memory and batch size 32, the total time for all the loops is approximately 10 minutes. You may need to reduce the batch size or limit the number of images.
deletion_steps = 10
avg_auc = []
deletion_steps_arr = np.linspace(0, 1, deletion_steps)
logits_fn, variables = load_resnet(size=18)
fns_gc, variables_gc = load_resnet_for_grad_cam(size=18)
ds = ds_builder.as_dataset(split="train", batch_size=None, as_supervised=True)
ds = ds.map(resize, num_parallel_calls=tf.data.AUTOTUNE, deterministic=True)
ds = ds.batch(32, drop_remainder=True, num_parallel_calls=tf.data.AUTOTUNE, deterministic=True)
ds = ds.take(50) # limit number of batches
ds = ds.prefetch(4)
ds = tfds.as_numpy(ds)
for method, method_fn in all_methods.items():
aucs = []
for images, labels in tqdm.tqdm(ds, ncols=0, desc=method):
_, relevance = method_fn(images)
probs, _ = deletion_score_fn_vmap(logits_fn, variables, images, relevance, deletion_steps)
aucs.extend(sklearn.metrics.auc(deletion_steps_arr, p) for p in probs)
avg_auc.append({"method": method, "mean": np.mean(aucs), "std": np.std(aucs)})
avg_auc = pd.DataFrame(avg_auc)
display(avg_auc.set_index("method"))
fig, ax = plt.subplots(1, 1, figsize=(8, 4), facecolor="white")
avg_auc.plot(
"method",
"mean",
yerr="std",
kind="bar",
rot=0,
figsize=(10, 5),
legend=None,
ylim=(0, 1),
xlabel="",
title="Average deletion score (lower is better)",
ax=ax
)
display(fig)
plt.close(fig)
occlusion: 100% 50/50 [15:33<00:00, 18.67s/it]
grad: 100% 50/50 [04:28<00:00, 5.37s/it]
grad_x_input: 100% 50/50 [04:33<00:00, 5.47s/it]
grad_cam: 100% 50/50 [03:59<00:00, 4.79s/it]
integrated_gradients: 100% 50/50 [41:14<00:00, 49.48s/it]
| method | mean | std |
|---|---|---|
| occlusion | 0.328927 | 0.227129 |
| grad | 0.257688 | 0.231876 |
| grad_x_input | 0.219738 | 0.185848 |
| grad_cam | 0.307005 | 0.206437 |
| integrated_gradients | 0.210839 | 0.185165 |
Conclusion¶
To help us improve this practical:
- How long did it take to complete this notebook? 27 hrs
- What was the most difficult part? The GradCAM implementation took quite some time, as did interpreting the results of each method.